    
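### Doubly robust (DR) estimation of alpha for the LATE / MLATE target
## Input:
##      ml.est: maximum likelihood estimates (alpha first, followed by the
##              nuisance-model coefficients)
##      Z, D, Y: instrument, treatment and outcome vectors
##      X, X.phi, X.pi: design matrices for theta, the nuisance models and
##              the instrument propensity score
##      target: target parameter, "LATE" or "MLATE"
## Output:
##      list(val = DR estimate of alpha, error = NA, or the caught error)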
    
# Optimal-weight multiplier for the DR estimating equation.
#
# The optimal weight function is wt * X; this returns the `wt` part (one value
# per observation). The X factor is applied later, in the score itself.
#
# Args:
#   theta, phi1, phi2, phi3, phi4, op: nuisance quantities evaluated at the ML
#     estimates, as computed in DREst().
#   pi_hat: fitted instrument propensity score P(Z = 1 | X.pi).
#   target: "LATE" or "MLATE".
# Returns:
#   Numeric weights, one per observation.
.dr_optimal_weights <- function(theta, phi1, phi2, phi3, phi4, op, pi_hat,
                                target) {
  # NOTE(review): columns are assumed to be ordered p000, p001, ..., p111,
  # matching the labels used below; confirm against getProb().
  p <- getProb(theta, phi1, phi2, phi3, phi4, op, target)
  p010 <- p[, 3]
  p011 <- p[, 4]
  p100 <- p[, 5]
  p101 <- p[, 6]
  p110 <- p[, 7]
  p111 <- p[, 8]

  if (target == "LATE") {
    p1y <- p011 + p111
    p0y <- p010 + p110
    p1d <- p101 + p111
    p0d <- p100 + p110
    H2_1 <- p1y + theta^2 * p1d - 2 * theta * p111
    H2_0 <- p0y + theta^2 * p0d - 2 * theta * p110
    H_X <- p0y - theta * p0d
    temp <- (H2_1 - H_X^2) / pi_hat + (H2_0 - H_X^2) / (1 - pi_hat)
    # 1 - theta^2 = d tanh(eta) / d eta: d theta / d alpha up to the factor X.
    -(1 - theta^2) * phi1 / temp
  } else {
    H2_1 <- p111 / theta^2 + p011
    H2_0 <- p110 / theta^2 + p010
    H_X <- p110 / theta + p010
    temp <- (H2_1 - H_X^2) / pi_hat + (H2_0 - H_X^2) / (1 - pi_hat)
    f1phi1 <- p111 - p110
    # theta = d exp(eta) / d eta: d theta / d alpha up to the factor X.
    -theta * f1phi1 / (theta^2 * temp)
  }
}

# Re-minimise `objective` from seeded Uniform(-1, 1) starts while the
# estimate is implausibly large (max |pars| > bound), for at most `max_steps`
# attempts. Attempt k uses set.seed(k), so results are reproducible. The
# caller's RNG state (.Random.seed) is restored on exit, so DREst() no longer
# resets the caller's random-number stream as a side effect.
#
# Returns the parameters from the last optimisation run (names of `pars` are
# kept), or `pars` unchanged when indicator <= bound.
.dr_restart_search <- function(objective, pars, indicator, bound = 10,
                               max_steps = 50) {
  if (indicator <= bound) {
    return(pars)
  }

  env <- globalenv()
  had_seed <- exists(".Random.seed", envir = env, inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = env, inherits = FALSE)
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = env)
    } else if (exists(".Random.seed", envir = env, inherits = FALSE)) {
      rm(".Random.seed", envir = env)
    }
  }, add = TRUE)

  step <- 0
  while (indicator > bound && step < max_steps) {
    step <- step + 1
    set.seed(step)
    initial <- runif(length(pars), -1, 1)
    pars[] <- optim(initial, objective, control = list(maxit = 10000))$par
    indicator <- max(abs(pars))
  }
  pars
}

# Doubly robust estimate of alpha for the "LATE" or "MLATE" target.
#
# Args:
#   Z: binary instrument (modelled with a logistic regression on X.pi).
#   D, Y: treatment and outcome vectors.
#   X: design for theta; X.phi / X.pi: designs for the nuisance models and
#     the instrument propensity score (X.phi must have ncol(X) columns).
#   target: "LATE" or "MLATE".
#   optimal: use the optimal weight function (TRUE) or unit weights (FALSE).
#   ml.est: ML estimates stacked as six blocks of ncol(X) coefficients:
#     alpha, phi1, phi2, phi3, phi4 and log(op).
# Returns:
#   list(val = DR estimate of alpha, error = NA) on success, or
#   list(val = rep(NA, ncol(X)), error = <condition>) if any step errors.
DREst <- function(Z, D, Y, X, X.phi = X, X.pi = X, target = "LATE",
                  optimal = TRUE, ml.est) {
  # Computed outside tryCatch() with NCOL() so the error handler can always
  # size its NA result (dim(X)[2] is NULL for a plain vector).
  pX <- NCOL(X)
  alpha_idx <- seq_len(pX)

  tryCatch({
    # Reject unknown targets up front; otherwise `theta`/`H` below would stay
    # unassigned and could silently resolve to globals of the same name.
    if (length(target) != 1L || !(target %in% c("LATE", "MLATE"))) {
      stop("`target` must be \"LATE\" or \"MLATE\".", call. = FALSE)
    }

    ## Nuisance parameters at the ML estimates.
    block <- function(k) ml.est[(k * pX + 1):((k + 1) * pX)]
    eta_ml <- X %*% block(0)
    theta <- if (target == "LATE") tanh(eta_ml) else exp(eta_ml)
    phi1 <- expit(X.phi %*% block(1))
    phi2 <- expit(X.phi %*% block(2))
    phi3 <- expit(X.phi %*% block(3))
    phi4 <- expit(X.phi %*% block(4))
    log_op <- X.phi %*% block(5)
    op <- exp(log_op)

    ## Instrument propensity score from a no-intercept logistic regression.
    gamma <- coef(glm(Z ~ X.pi - 1, family = binomial))
    pi_hat <- as.vector(expit(X.pi %*% gamma))
    # NOTE(review): presumably the density of the observed Z given X
    # (pi_hat where Z == 1, 1 - pi_hat where Z == 0); confirm in GetfZ().
    fZ_X <- GetfZ(pi_hat, 1 - pi_hat, Z)

    ## The weight function is wt * X; unit weights unless `optimal`.
    wt <- if (optimal) {
      .dr_optimal_weights(theta, phi1, phi2, phi3, phi4, op, pi_hat, target)
    } else {
      rep(1, length(pi_hat))
    }

    ## Squared norm of the DR estimating equation, as a function of alpha.
    dr.objective <- function(pars) {
      eta <- X %*% pars
      if (target == "LATE") {
        theta <- tanh(eta)
        # Pass eta itself rather than atanh(theta): tanh() rounds to +-1 for
        # large |eta|, and atanh(+-1) = +-Inf would poison the objective.
        f0 <- mapply(getProbScalarRD, eta, log_op)[1, ]
        H <- Y - theta * D
        H_X <- f0 * phi1 + (1 - phi1) * (1 - phi2) * phi3 +
          (1 - phi1) * phi2 * phi4 - theta * (1 - phi1) * phi2
      } else {
        theta <- exp(eta)
        # eta == log(theta) without the overflow of log(exp(eta)).
        f0 <- mapply(getProbScalarRR, eta, log_op)[1, ]
        H <- Y * theta^(-D)
        H_X <- f0 * phi1 + (1 - phi1) * phi2 * phi4 / theta +
          (1 - phi1) * (1 - phi2) * phi3
      }
      score <- crossprod(X, wt * (2 * Z - 1) * (H - H_X) / fZ_X)
      sum(score^2)
    }

    ## Minimise from the zero vector and from the ML alpha, skipping starts
    ## where the objective is not finite; keep the better fit (the zero start
    ## wins ties). `alpha[] <-` keeps the names carried by ml.est.
    alpha <- ml.est[alpha_idx]
    starts <- list(rep(0, pX), ml.est[alpha_idx])
    usable <- vapply(starts, function(s) is.finite(dr.objective(s)),
                     logical(1))
    fits <- lapply(starts[usable], optim, fn = dr.objective,
                   control = list(maxit = 10000))
    if (length(fits) > 0L) {
      best <- which.min(vapply(fits, function(f) f$value, numeric(1)))
      alpha[] <- fits[[best]]$par
      indicator <- max(abs(alpha))
    } else {
      # Neither start is usable: force the random-restart search.
      indicator <- 100
    }

    ## Retry from seeded random starts while the estimate is implausibly large.
    alpha <- .dr_restart_search(dr.objective, alpha, indicator)

    list(val = alpha, error = NA)
  }, error = function(e) {
    list(val = rep(NA, pX), error = e)
  })
}

    
    
    